#!/usr/bin/env python
# -*- coding: utf-8 -*-
import pathlib
import pickle
from datetime import datetime, time, date
from itertools import repeat
from multiprocessing import Pool
from typing import Dict, List, Tuple
from functools import partial
import pandas as pd
from absl import app, flags
from easydict import EasyDict
from rqalpha import run_func
from rqalpha.utils.logger import user_log
from ruamel import yaml

from quant import strategy, utils
from quant.trader import trader, rqalpha
from quant.vlogger import VLogger

# Command-line flags, parsed by absl when app.run(main) starts.
# --config: top-level key of backtest.yml to run; every key is run when unset.
flags.DEFINE_string('config', None, 'Backtest config')
# --strategy: name under the config's `strategies` section. With --dt it picks
# the strategy to evaluate; otherwise it restricts the backtest to this
# strategy (plus Benchmark-class strategies when --benchmark is set).
flags.DEFINE_string('strategy', None, 'Strategy in config')
# --dt: ISO date (YYYY-MM-DD). When set, no backtest runs; instead the
# strategy's generate_signal() is called once for that day at 15:00.
# Requires --strategy and --params_vars.
flags.DEFINE_string('dt', None, 'Date string')
# --cache: reuse <backtest_dir>/<strategy description>.pkl when it exists.
flags.DEFINE_bool('cache', True, 'Use cached result')
# --benchmark: together with --strategy, also run Benchmark-class strategies.
flags.DEFINE_bool('benchmark', True, 'Run benchmark')
# --params_vars: Python dict literal (evaluated with eval()). In backtest mode
# each key maps to a list of candidate values and every combination becomes
# one task; with --dt the dict is passed to strategy_cfg() as-is.
# A list of dicts (see "weight" below) varies several parameters together.
# ("周" is the literal value the strategies expect for weekly rebalancing.)
# Example usage:
# ./quant/backtest/backtest.py --config="conbond" --strategy='double_low' \
# --params_vars='{"threshold": [130], "rebalance": ["周"], "days_to_maturity": [10], "days_to_stop_trading": [10]}'
# ./quant/backtest/backtest.py --config="conbond" --strategy='double_low_rank' \
# --params_vars='{"threshold": [130], "rebalance": ["周"], "days_to_maturity": [10], "days_to_stop_trading": [10], \
# "weight": [{"weight_price": 0.2, "weight_cpr": 0.8}]}'
flags.DEFINE_string('params_vars', None, 'Used to run single conbond strategy')
# --v: verbosity level assigned to VLogger.v.
flags.DEFINE_integer('v', 0, 'Verbose level')

FLAGS = flags.FLAGS


def run_strategy(backtest_dir: pathlib.Path, st: strategy.TradeStrategy,
                 config: EasyDict) -> Dict:
    """Backtest one strategy with rqalpha and cache the analyser result.

    The ``sys_analyser`` entry returned by ``run_func`` is tagged with the
    strategy description (``r['summary']['strategy_name']``), gets the
    concatenated signals under ``'signals'`` when any were collected, and is
    pickled to ``<backtest_dir>/<st.description>.pkl``.

    Args:
        backtest_dir: Directory for the pickled result; created if missing.
        st: Strategy providing ``init``, ``handle_bar``, ``description`` and
            the list of collected ``signals`` DataFrames.
        config: rqalpha run configuration (converted to a plain dict).

    Returns:
        The ``sys_analyser`` result dict.

    Raises:
        Re-raises any exception from the run or from saving the result.
        Before that, signals collected so far (if any) are dumped to
        ``/tmp/<st.description>.csv`` for inspection.
    """
    p = backtest_dir.joinpath('%s.pkl' % st.description)
    VLogger.info('Start backtesting: %s' % st.description)
    try:
        r = run_func(init=st.init,
                     handle_bar=st.handle_bar,
                     config=dict(config))['sys_analyser']
        VLogger.info('Done backtesting: %s' % st.description)
        r['summary']['strategy_name'] = st.description
        if len(st.signals) > 0:
            r['signals'] = pd.concat(st.signals)
        # Make sure a long run is not lost because the output dir is missing.
        p.parent.mkdir(parents=True, exist_ok=True)
        with p.open(mode='wb') as f:
            pickle.dump(r, f)
        return r
    except Exception as e:
        # Best-effort debug dump. pd.concat([]) raises, and a failing dump
        # must never hide the original error, so guard both cases.
        if len(st.signals) > 0:
            try:
                pd.concat(st.signals).to_csv('/tmp/%s.csv' % st.description)
            except Exception as dump_err:
                VLogger.error('Failed to dump signals: %s' % dump_err)
        VLogger.error(str(e))
        raise


def gen_vars_dict(pdict: Dict) -> pd.DataFrame:
    """Expand a parameter grid into one row per parameter combination.

    Each value of ``pdict`` is a list of candidate values:

    * a list of plain values becomes one column named after its key;
    * a list of dicts contributes one column per dict key (the outer key is
      dropped), so several parameters vary together. Values are matched by
      key, so the dicts may list their keys in any order.

    The result is the cross product of all entries, in ``pdict`` order.

    Example::

        gen_vars_dict({'threshold': [130, 140],
                       'weight': [{'weight_price': 0.2, 'weight_cpr': 0.8}]})
        # -> 2 rows; columns threshold, weight_price, weight_cpr

    Args:
        pdict: Mapping of parameter name -> list of candidate values.

    Returns:
        DataFrame with one row per combination (no rows if any candidate
        list is empty), or ``None`` when ``pdict`` itself is empty.
    """
    df_params = None
    for k, v in pdict.items():
        if len(v) > 0 and isinstance(v[0], dict):
            # Build from records so every value lands in the column of its
            # own key instead of being assigned by position.
            df = pd.DataFrame([dict(d) for d in v])
        else:
            df = pd.DataFrame.from_dict({k: v})
        if df_params is None:
            df_params = df
        else:
            df_params = df_params.merge(df, how='cross')
    return df_params


def _as_day_str(value):
    """Return a date/datetime as a 'YYYY-MM-DD' string; other values unchanged.

    Bar filtering compares day strings, so a bound that was loaded as a
    ``date`` object (e.g. an unquoted date in the YAML config) must be
    converted first.
    """
    if isinstance(value, date):
        return value.strftime('%Y-%m-%d')
    return value


def prepare_data(bt: EasyDict) -> Tuple[pd.DataFrame, pd.DataFrame]:
    """Load instruments and bars from the local CSV cache.

    Args:
        bt: Backtest config. Reads ``bt.cache_dir``; ``bt.instruments``
            (instrument CSV names relative to the cache dir); ``bt.bars``
            (mapping of glob pattern relative to the cache dir -> object with
            inclusive ``start``/``end`` day bounds); and
            ``bt.vars.frequency``.

    Returns:
        ``(bars, instruments)``.
        * ``bars`` is sorted and indexed by ``(datetime, order_book_id)``.
          When the frequency is ``'1d'``, each bar's datetime is moved to
          15:00 on its day.
        * ``instruments`` is indexed by ``order_book_id``, with rows lacking
          a ``listed_date`` removed.

    Raises:
        ValueError: if no file matches any pattern in ``bt.bars``.
        AssertionError: if the matched bars are empty after filtering.
    """
    cache_dir = pathlib.Path(bt.cache_dir).expanduser()
    all_instruments = []
    for fins in bt.instruments:
        instruments = pd.read_csv(cache_dir.joinpath(fins),
                                  index_col=['order_book_id'])
        # Remove not-listed instruments: a NaN listed_date makes rqalpha fail.
        all_instruments.append(instruments[instruments.listed_date.notna()])
    df_instruments = pd.concat(all_instruments)

    dfs = []
    for pattern, window in bt.bars.items():
        start, end = _as_day_str(window.start), _as_day_str(window.end)
        p = cache_dir.joinpath(pattern)
        for f in p.parent.glob(p.name):
            df = pd.read_csv(f, parse_dates=['datetime'])
            # ISO 'YYYY-MM-DD' strings order the same way as the dates.
            day = df.datetime.dt.date.astype(str)
            df = df[(day >= start) & (day <= end)]
            if 'yield_to_maturity' in df.columns:
                # Drop rows without a yield, e.g. 127013.XSHE.
                df = df[df.yield_to_maturity.notna()]
            # Not inplace: df is a filtered slice of the loaded frame.
            dfs.append(df.fillna({'volume': 0}))
    if not dfs:
        raise ValueError('No bar files match %s under %s' %
                         (list(bt.bars.keys()), cache_dir))
    bars = pd.concat(dfs)
    assert not bars.empty
    if bt.vars.frequency == '1d':
        # Stamp every daily bar at 15:00 on its trading day.
        bars['datetime'] = bars.datetime.apply(
            lambda dt: datetime.combine(dt.date(), time(15, 0, 0)))
    bars.sort_values(['datetime', 'order_book_id'], inplace=True)
    bars.set_index(['datetime', 'order_book_id'], inplace=True)
    return bars, df_instruments


def get_conbond_data(instruments: pd.DataFrame, bars: pd.DataFrame, context,
                     _):
    """Return the bars at ``context.now`` joined with instrument metadata.

    ``bars`` is indexed by ``(datetime, order_book_id)`` and ``instruments``
    by ``order_book_id``. The last argument is unused; it exists so the
    function matches the ``(context, bar_dict)`` callback shape once
    ``instruments`` and ``bars`` are bound with ``functools.partial``.
    """
    meta_columns = ['bond_type', 'symbol', 'de_listed_date', 'listed_date']
    current_rows = pd.IndexSlice[context.now, :]
    snapshot = bars.loc(axis=0)[current_rows]
    return snapshot.join(instruments[meta_columns])


def get_conbond_data2(context, bar_dict, td: trader.Trader):
    """Build the convertible-bond snapshot from rqalpha's live ``bar_dict``.

    This is an alternative to ``get_conbond_data`` that asks the trader for
    the instruments on ``context.now`` instead of using preloaded data.

    Args:
        context: rqalpha context; only ``context.now`` is read.
        bar_dict: Mapping order_book_id -> bar object exposing ``_data``, a
            record that is assumed to contain ``datetime`` and
            ``order_book_id`` fields (TODO confirm against rqalpha).
        td: Trader used to list the CONVERTIBLE instruments.

    Returns:
        DataFrame indexed by ``(datetime, order_book_id)`` with the bar fields
        plus ``bond_type``, ``symbol``, ``de_listed_date`` and
        ``listed_date``.
    """
    instruments = td.all_instruments(ins_type='CONVERTIBLE', dt=context.now)
    # Not inplace: the frame returned by the trader must not be mutated.
    instruments = instruments.set_index('order_book_id')
    rows = [bar_dict[order_book_id]._data
            for order_book_id in instruments.index.tolist()]
    bars = pd.DataFrame(rows).set_index(['datetime', 'order_book_id'])
    return bars.join(instruments[[
        'bond_type', 'symbol', 'de_listed_date', 'listed_date'
    ]])


def make_tasks(bt: EasyDict) -> List[strategy.TradeStrategy]:
    """Prepare the rqalpha config and instantiate every strategy task.

    Side effects on ``bt``:
    * ``bt.vars['conbond_data_fn']`` and ``bt.vars['trader']`` are set.
    * Each strategy's ``params`` is filled from ``bt.vars`` wherever the
      value is None, and ``params.cfg`` is overwritten for each task.
    * ``bt.config`` receives the rqalpha config loaded from ``rqcfg.yml``
      next to this file. Null ``base`` entries are filled from ``bt.vars``,
      and the local data source gets the prepared bars and instruments.

    A strategy with ``params_vars`` yields one task per parameter
    combination from ``gen_vars_dict``; a strategy without it yields one task
    with an empty variable dict.
    """
    rq_cfg_path = pathlib.Path(__file__).parent / 'rqcfg.yml'
    with rq_cfg_path.open(mode='r') as fh:
        config = EasyDict(yaml.safe_load(fh))
    for key, value in config.base.items():
        if value is None:
            config.base[key] = bt.vars[key]
    config.mod_configs = {}

    bars, instruments = prepare_data(bt)
    bt.vars['conbond_data_fn'] = partial(get_conbond_data, instruments, bars)
    bt.vars['trader'] = rqalpha.RqTrader(bt.vars.frequency)
    local_source = config.mod.local_data_source
    local_source.bars = bars
    local_source.instruments = instruments

    tasks = []
    for name, task in bt.strategies.items():
        params = task.params
        for var in params:
            if params[var] is None:
                params[var] = bt.vars[var]
        if 'params_vars' in task:
            combos = gen_vars_dict(task.params_vars).to_dict('records')
        else:
            combos = [{}]
        for vars_dict in combos:
            params.cfg = strategy.ConbondStrategy.strategy_cfg(name, vars_dict)
            tasks.append(getattr(strategy, task.cls)(**params))
    bt.config = config
    return tasks


def _generate_signal_for_day(bt: EasyDict) -> None:
    """Call FLAGS.strategy's ``generate_signal`` for the single day FLAGS.dt.

    No backtest is run. Requires --strategy and --params_vars.
    """
    assert FLAGS.strategy is not None
    stcfg = bt.strategies[FLAGS.strategy]
    bars, instruments = prepare_data(bt)
    assert FLAGS.params_vars is not None
    # NOTE(review): eval() executes arbitrary code taken from the command
    # line. ast.literal_eval would cover the dict literals in the usage
    # examples; eval is acceptable only while the CLI is operator-only.
    stcfg.params.cfg = strategy.ConbondStrategy.strategy_cfg(
        FLAGS.strategy, eval(FLAGS.params_vars))
    st = getattr(strategy, stcfg.cls)(**stcfg.params)
    # 15:00 matches the timestamp prepare_data gives daily ('1d') bars.
    dt = datetime.combine(date.fromisoformat(FLAGS.dt), time(15, 0, 0))
    VLogger.v = 2
    st.generate_signal(
        get_conbond_data(instruments, bars, EasyDict({'now': dt}), None), dt)


def _run_backtests(bt: EasyDict, config: str, cwd: pathlib.Path) -> None:
    """Backtest the strategies of one config and plot all results.

    Results cached in /data/quant/backtest/<config> are reused when --cache
    is set. The remaining tasks run in a process pool after a confirmation
    prompt. The plot is saved as <cwd>/<config>.png.
    """
    bt.backtest_dir = pathlib.Path('/data/quant/backtest').joinpath(config)
    if FLAGS.strategy is not None:
        strategies = {FLAGS.strategy: bt.strategies[FLAGS.strategy]}
        if FLAGS.params_vars is not None:
            # NOTE(review): eval() on command-line input, see the note in
            # _generate_signal_for_day.
            strategies[FLAGS.strategy].params_vars = eval(FLAGS.params_vars)
        for k, v in bt.strategies.items():
            if FLAGS.benchmark and v.cls == 'Benchmark':
                strategies[k] = v
        bt.strategies = strategies
    tasks = make_tasks(bt)
    tasks_to_run = []
    results = []
    for task in tasks:
        p = bt.backtest_dir.joinpath('%s.pkl' % task.description)
        if FLAGS.cache and p.exists():
            VLogger.info('Cached: %s' % task.description)
            with p.open(mode='rb') as f:
                results.append(pickle.load(f))
        else:
            tasks_to_run.append(task)
    for task in tasks_to_run:
        VLogger.info('Backtest Task: %s' % task.description)
    VLogger.info("Number of tasks: %s" % len(tasks_to_run))
    if tasks_to_run:
        import click
        if click.confirm('Continue?'):
            # Switch VLogger to rqalpha's user_log before starting workers.
            VLogger.logger = user_log
            # `parallel` may be null or absent in backtest.yml. Attribute
            # access on a missing EasyDict key raises AttributeError, so use
            # .get().
            if bt.get('parallel') is None:
                bt.parallel = len(bt.strategies)
            with Pool(bt.parallel) as pool:
                results += pool.starmap(
                    run_strategy,
                    zip(repeat(bt.backtest_dir), tasks_to_run,
                        repeat(bt.config)))
    utils.plot_rqalpha_backtest_results(
        datetime.now(),
        bt.vars.start_date,
        bt.vars.end_date,
        {r['summary']['strategy_name']: r
         for r in results},
        savefile=cwd.joinpath('%s.png' % config))


def main(_):
    """absl entry point: process the selected configs from backtest.yml.

    Each config is either every top-level key of backtest.yml (next to this
    file) or only --config. With --dt, one strategy's ``generate_signal`` is
    called for that day. Otherwise the config's backtests run and their
    results are plotted.
    """
    import logging
    logging.getLogger().setLevel(logging.INFO)
    VLogger.v = FLAGS.v
    cwd = pathlib.Path(__file__).parent
    with cwd.joinpath('backtest.yml').open(mode='r') as f:
        # NOTE(review): UnsafeLoader can instantiate arbitrary Python objects.
        # It is acceptable only because backtest.yml is a trusted local file.
        docs = yaml.load(f, Loader=yaml.UnsafeLoader)
    for config in docs.keys() if FLAGS.config is None else [FLAGS.config]:
        assert config in docs
        bt = EasyDict(docs[config])
        VLogger.info('Backtesting config: %s' % config)
        if FLAGS.dt is not None:
            _generate_signal_for_day(bt)
        else:
            _run_backtests(bt, config, cwd)


if __name__ == '__main__':
    # absl parses the command-line flags defined at module level, then calls
    # main() with the remaining argv.
    app.run(main)
